import io
import os

import torch
import nori2 as nori

from PIL import Image



class ImageNet(torch.utils.data.Dataset):
    """ImageNet / ImageNet-100 dataset whose images are fetched from nori.

    Each non-blank line of the selected ``*.nori.list`` file is
    whitespace-split. The first field is passed to ``nori.Fetcher().get`` to
    obtain the encoded image bytes. The second field is parsed as the integer
    class label.

    Args:
        train: Use the ``train`` list if truthy, otherwise the ``val`` list.
        transform: Optional callable applied to the decoded PIL image.
        ws: Choose between the two hard-coded list directories. It is
            ignored when ``root`` is given.
        nori_prefix: If it contains ``'1000'``, the full ImageNet lists
            (``imagenet.<split>.nori.list``) are used. Otherwise, including
            the default ``'imagenet'``, the ImageNet-100 lists
            (``imagenet100.<split>.nori.list``) are used.
        root: Optional directory containing the list file. It overrides the
            hard-coded location selected by ``ws``.
    """

    _WS_ROOT = '/unsullied/sharefs/g:brain/imagenet/ILSVRC2012/'
    _DEFAULT_ROOT = '/data/Dataset/ImageNet2012/'

    def __init__(self, train, transform=None, ws=False, nori_prefix='imagenet',
                 root=None):
        super().__init__()
        split = 'train' if train else 'val'
        # NOTE: the default nori_prefix ('imagenet') does not contain '1000',
        # so by default the ImageNet-100 lists are selected.
        dataset = 'imagenet' if '1000' in nori_prefix else 'imagenet100'
        nori_name = f'{dataset}.{split}.nori.list'

        if root is None:
            root = self._WS_ROOT if ws else self._DEFAULT_ROOT
        nori_path = os.path.join(root, nori_name)

        # The fetcher is created lazily on first __getitem__ call rather than
        # here. This is presumably so each DataLoader worker process builds
        # its own instance (TODO confirm nori.Fetcher must not be shared or
        # pickled).
        self.f = None
        self.transform = transform
        self.f_list = self._read_nori_list(nori_path)

    @staticmethod
    def _read_nori_list(path):
        """Return the whitespace-split fields of every non-blank line in *path*."""
        with open(path, encoding='utf-8') as fh:
            # Blank lines (e.g. a trailing empty line) split to [] and would
            # make __getitem__ fail with IndexError, so they are skipped.
            return [fields for fields in (line.split() for line in fh) if fields]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for entry *idx*.

        ``image`` is the decoded PIL image, or ``transform(image)`` when a
        transform was given. ``label`` is ``int`` of the entry's second field.
        """
        if self.f is None:
            self.f = nori.Fetcher()

        fields = self.f_list[idx]
        # NOTE(review): the image is returned in whatever mode it was stored
        # in (e.g. grayscale/CMYK are not converted to RGB). Confirm the
        # stored data is already RGB, or add .convert('RGB') if not.
        img = Image.open(io.BytesIO(self.f.get(fields[0])))
        # Previously `img` was only bound inside this branch, so
        # transform=None raised UnboundLocalError.
        if self.transform is not None:
            img = self.transform(img)
        return img, int(fields[1])

    def __len__(self):
        """Return the number of entries read from the list file."""
        return len(self.f_list)
